# Copyright (c) 2025 The HuggingFace Team.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0 
#
# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025.
#
# Original file was released under Apache License 2.0, with the full license text
# available at https://github.com/huggingface/finetrainers/blob/main/LICENSE.
#
# This modified file is released under the same license.

import contextlib
import copy
import pathlib
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from .logging import get_logger
from .utils import Timer, TimerDevice


logger = get_logger()


class BaseTracker:
    r"""Base class for loggers. Does nothing by default, so it is useful when you want to disable logging."""

    def __init__(self):
        self._timed_metrics = {}

    @contextlib.contextmanager
    def timed(self, name: str, device: TimerDevice = TimerDevice.CPU, device_sync: bool = False):
        r"""Context manager to track time for a specific operation."""
        timer = Timer(name, device, device_sync)
        timer.start()
        yield timer
        timer.end()
        elapsed_time = timer.elapsed_time
        if name in self._timed_metrics:
            # If the timer name already exists, add the elapsed time to the existing value since a log has not been invoked yet
            self._timed_metrics[name] += elapsed_time
        else:
            self._timed_metrics[name] = elapsed_time

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        pass

    def finish(self) -> None:
        pass


class DummyTracker(BaseTracker):
    def __init__(self):
        super().__init__()

    def log(self, *args, **kwargs):
        pass

    def finish(self) -> None:
        pass


class WandbTracker(BaseTracker):
    r"""Logger implementation for Weights & Biases."""

    def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str, Any]] = None, runs_name: Optional[str] = None) -> None:
        super().__init__()

        import wandb

        self.wandb = wandb

        # WandB does not create a directory if it does not exist and instead starts using the system temp directory.
        pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)

        self.run = wandb.init(project=experiment_name, dir=log_dir, config=config, name=runs_name)
        logger.info("WandB logging enabled")

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        metrics = {**self._timed_metrics, **metrics}
        self.run.log(metrics, step=step)
        self._timed_metrics = {}

    def finish(self) -> None:
        self.run.finish()


class SequentialTracker(BaseTracker):
    r"""Sequential tracker that logs to multiple trackers in sequence."""

    def __init__(self, trackers: List[BaseTracker]) -> None:
        super().__init__()
        self.trackers = trackers

    @contextlib.contextmanager
    def timed(self, name: str, device: TimerDevice = TimerDevice.CPU, device_sync: bool = False):
        r"""Context manager to track time for a specific operation."""
        timer = Timer(name, device, device_sync)
        timer.start()
        yield timer
        timer.end()
        elapsed_time = timer.elapsed_time
        if name in self._timed_metrics:
            # If the timer name already exists, add the elapsed time to the existing value since a log has not been invoked yet
            self._timed_metrics[name] += elapsed_time
        else:
            self._timed_metrics[name] = elapsed_time
        for tracker in self.trackers:
            tracker._timed_metrics = copy.deepcopy(self._timed_metrics)

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        for tracker in self.trackers:
            tracker.log(metrics, step)
        self._timed_metrics = {}

    def finish(self) -> None:
        for tracker in self.trackers:
            tracker.finish()


class Trackers(str, Enum):
    r"""Enum for supported trackers."""

    NONE = "none"
    WANDB = "wandb"


_SUPPORTED_TRACKERS = [tracker.value for tracker in Trackers.__members__.values()]


def initialize_trackers(
    trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str, runs_name: str
) -> Union[BaseTracker, SequentialTracker]:
    r"""Initialize loggers based on the provided configuration."""

    logger.info(f"Initializing trackers: {trackers}. Logging to {log_dir=}")

    if len(trackers) == 0:
        return BaseTracker()

    if any(tracker_name not in _SUPPORTED_TRACKERS for tracker_name in set(trackers)):
        raise ValueError(f"Unsupported tracker(s) provided. Supported trackers: {_SUPPORTED_TRACKERS}")

    tracker_instances = []
    for tracker_name in set(trackers):
        if tracker_name == Trackers.NONE:
            tracker = BaseTracker()
        elif tracker_name == Trackers.WANDB:
            tracker = WandbTracker(experiment_name, log_dir, config, runs_name)
        tracker_instances.append(tracker)

    tracker = SequentialTracker(tracker_instances)
    return tracker


TrackerType = Union[BaseTracker, SequentialTracker, WandbTracker]
